#
# Copyright (C) 2017-2019  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""
Generate lookup tables for cosmological priors.

This is a lookup table for the BAYESTAR uniform-in-comoving-volume prior. The
BAYESTAR posterior distribution has a built-in prior that is uniform in naively
Euclidean luminosity distance space (proportional to DL^2). When we want to
switch to a prior that is uniform in comoving volume, we have to multiply by
the factor dVC/dVL, the differential comoving volume per unit Euclidean
luminosity volume.

The lookup table is in terms of the *natural logarithm* of the luminosity
distance and provides the *natural logarithm* of dVC/dVL. It consists of the
following constants:

    dVC_dVL_tmin, dVC_dVL_tmax, dVC_dVL_dt
        Minimum, maximum, and step size for a uniform grid in t = log(DL),
        such that t[i] = dVC_dVL_tmin + i * dVC_dVL_dt.

    dVC_dVL_data
        Tabulated values of y = f(t) = log(dVC/dVL), such that y[i] = f(t[i]).

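    dVC_dVL_high_z_slope, dVC_dVL_high_z_intercept
        Linear fit for extrapolating f(t) for t > dVC_dVL_tmax.

In addition, the arrays dVC_dVL_test_x and dVC_dVL_test_y are written out:
luminosity distances in Mpc and the corresponding values of dVC/dVL (not their
logarithms), computed from an alternate expression for use as test vectors.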
"""
import argparse
import os
import astropy.cosmology
import astropy.units as u
import numpy as np
import scipy.misc

from ligo.skymap.postprocess.cosmology import z_for_DL, dVC_dVL_for_DL

# Command-line interface: one optional positional argument that names a
# cosmological model known to astropy.
parser = argparse.ArgumentParser()
parser.add_argument(
    'cosmology',
    nargs='?',
    default='Planck15',
    choices=astropy.cosmology.parameters.available,
    help='Cosmological model')
args = parser.parse_args()

# Turn the model name into an astropy cosmology object.
# NOTE(review): this object is only used by exact_dVC_dVL below; it is not
# passed to dVC_dVL_for_DL or z_for_DL, so the tabulated data presumably
# follow whatever cosmology ligo.skymap.postprocess.cosmology uses -- confirm
# that the two agree when a non-default model is requested.
cosmo = astropy.cosmology.default_cosmology.get_cosmology_from_string(
    args.cosmology)

# Coarse grid for the lookup table: 32 luminosity distances spaced uniformly
# in the logarithm from 1 to 1e6 (presumably Mpc, as in exact_dVC_dVL), and
# the natural logarithms of both the abscissa and the tabulated ratio.
DL = np.logspace(0, 6, num=32)
dVC_dVL = dVC_dVL_for_DL(DL)
log_DL, log_dVC_dVL = np.log(DL), np.log(dVC_dVL)


def func(x):
    """Evaluate log(dVC/dVL) as a function of x = log(DL).

    Parameters
    ----------
    x : float or array_like
        Natural logarithm of the luminosity distance.

    Returns
    -------
    Natural logarithm of ``dVC_dVL_for_DL`` evaluated at ``exp(x)``.
    """
    distance = np.exp(x)
    ratio = dVC_dVL_for_DL(distance)
    return np.log(ratio)


# Linear extrapolation of f(t) past the end of the grid: use the tangent line
# of log(dVC/dVL) versus log(DL), anchored at DL = 1e6.
high_z_x0 = np.log(1e6)
high_z_y0 = func(high_z_x0)
# Three-point central difference with unit step in log(DL). This is exactly
# the estimate that scipy.misc.derivative(func, high_z_x0) produced with its
# default arguments (dx=1.0, n=1, order=3). It is written out explicitly
# because that function was deprecated in SciPy 1.10 and has since been
# removed, which made this script fail on current SciPy releases.
high_z_slope = (func(high_z_x0 + 1.0) - func(high_z_x0 - 1.0)) / 2.0
high_z_intercept = high_z_y0 - high_z_slope * high_z_x0

# Emit the lookup table as C source on standard output.
filename = os.path.basename(__file__)
print('/* DO NOT EDIT. Automatically generated by', filename, '*/')
print('static const double dVC_dVL_data[] = {')
print(*('\t{:+.8e}'.format(c) for c in log_dVC_dVL), sep=',\n')
print('};')
# Grid description: t[i] = tmin + i * dt. The grid is uniform in log(DL), so
# the first difference is taken as the step size.
print('static const double dVC_dVL_tmin = {:.15f};'.format(log_DL[0]))
print('static const double dVC_dVL_tmax = {:.15f};'.format(log_DL[-1]))
print('static const double dVC_dVL_dt = {:.15f};'.format(
      np.diff(log_DL)[0]))
print('static const double dVC_dVL_high_z_slope = {:.15f};'.format(
      high_z_slope))
print('static const double dVC_dVL_high_z_intercept = {:.15f};'.format(
      high_z_intercept))


def exact_dVC_dVL(DL):
    """Compute dVC/dVL from the chain rule in redshift.

    This is an alternate expression of ``dVC_dVL_for_DL``: both volume
    elements are differentiated with respect to redshift using the
    module-level ``cosmo`` object, and the ratio of the two is returned.

    Parameters
    ----------
    DL : float or array_like
        Luminosity distance as a plain number; it is interpreted in Mpc.

    Returns
    -------
    The ratio of the differential comoving volume to the differential
    Euclidean luminosity volume, (dVC/dz) / (dVL/dz).

    Notes
    -----
    The redshift is obtained from ``z_for_DL``, which is not given ``cosmo``;
    presumably it must use the same cosmological model for the result to be
    self-consistent (to be confirmed against ligo.skymap).
    """
    z = z_for_DL(DL)
    lum_dist = DL * u.Mpc
    hubble_dist = cosmo.hubble_distance
    comoving = cosmo.comoving_distance(z)
    transverse = cosmo.comoving_transverse_distance(z)

    # Derivative of the transverse comoving distance with respect to the
    # line-of-sight comoving distance; its form depends on the sign of the
    # curvature density parameter.
    curvature = cosmo.Ok0
    if curvature > 0.0:
        dDM_dDC = np.cosh(np.sqrt(curvature) * comoving / hubble_dist)
    elif curvature == 0.0:
        dDM_dDC = 1.0
    else:  # curvature is negative, or nan
        dDM_dDC = np.cos(np.sqrt(-curvature) * comoving / hubble_dist)

    # DL = (1 + z) * DM, so dDL/dz follows from the product rule.
    dDC_dz = hubble_dist * cosmo.inv_efunc(z)
    dDL_dz = transverse + (1 + z) * dDM_dDC * dDC_dz
    # Euclidean volume element per unit solid angle: DL^2 dDL.
    dVL_dz = np.square(lum_dist) * dDL_dz / u.sr

    return cosmo.differential_comoving_volume(z) / dVL_dz


# Test vectors: evaluate the alternate expression on a grid of 1000 points
# that is finer than the lookup table and extends below it (DL from 1e-2 to
# 1e6), then emit abscissae and values as C arrays. Unlike the lookup table,
# these are written as plain values rather than logarithms.
DL = np.logspace(-2, 6, num=1000)
dVC_dVL = exact_dVC_dVL(DL)

print('static const double dVC_dVL_test_x[] = {')
print(',\n'.join('\t{:+.8e}'.format(c) for c in DL))
print('};')
print('static const double dVC_dVL_test_y[] = {')
print(',\n'.join('\t{:+.8e}'.format(c) for c in dVC_dVL))
print('};')
